/*
 * Vulkan Samples
 *
 * Copyright (C) 2015-2016 Valve Corporation
 * Copyright (C) 2015-2016 Valve Corporation
 * Copyright (C) 2015-2016 Google, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
VULKAN_SAMPLE_SHORT_DESCRIPTION
Demonstrate how to use SPIR-V shaders with inline assembly.
*/

#include <util_init.hpp>
#include <assert.h>
#include <string.h>
#include <cstdint>
#include <cstdlib>
#include "cube_data.h"
#include "spirv-tools/libspirv.h"

// clang-format off
// This sample is based on the template, but instead of using inline GLSL and calls to
// glslang to generate SPIR-V binaries, we use inline assembly and pass it to the
// SPIRV-Tools assembler.  This is one of many ways to generate SPIR-V binaries,
// which is the only shader representation accepted by Vulkan.

// The following inline SPIR-V assembly was generated by:
//   Populating template.vert and template.frag with contents of inlined GLSL from template sample
//   Running the following commands on Linux:
//     ./glslang/build/Standalone/bin/glslangValidator -V ./API-Samples/template.vert -o ./API-Samples/template.vert.spv
//     ./glslang/build/Standalone/bin/glslangValidator -V ./API-Samples/template.frag -o ./API-Samples/template.frag.spv
//     ./spirv-tools/build/spirv-dis ./API-Samples/template.vert.spv | sed -e 's/\"/\\\"/g' -e 's/.*/\"&\\n\"/'
//     ./spirv-tools/build/spirv-dis ./API-Samples/template.frag.spv | sed -e 's/\"/\\\"/g' -e 's/.*/\"&\\n\"/'

// Vertex shader in SPIR-V assembly text, assembled at runtime by sample_main().
// Interface, as declared by the decorations below:
//   - input  %pos         : vec4, Location 0
//   - input  %inTexCoords : vec2, Location 1
//   - output %texcoord    : vec2, Location 0 (consumed by the fragment shader)
//   - uniform block %ubuf (type %buf, single mat4 member "mvp"),
//     DescriptorSet 0, Binding 0
// The body copies inTexCoords to texcoord and writes mvp * pos to gl_Position.
// Lines beginning with ';' are assembly comments carried over from the
// disassembler output.
const std::string vertexSPIRV =
        "; SPIR-V\n"
        "; Version: 1.0\n"
        "; Generator: Khronos Glslang Reference Front End; 1\n"
        "; Bound: 35\n"
        "; Schema: 0\n"
        "               OpCapability Shader\n"
        "          %1 = OpExtInstImport \"GLSL.std.450\"\n"
        "               OpMemoryModel Logical GLSL450\n"
        "               OpEntryPoint Vertex %main \"main\" %texcoord %inTexCoords %_ %pos\n"
        "               OpSource GLSL 400\n"
        "               OpSourceExtension \"GL_ARB_separate_shader_objects\"\n"
        "               OpSourceExtension \"GL_ARB_shading_language_420pack\"\n"
        "               OpName %main \"main\"\n"
        "               OpName %texcoord \"texcoord\"\n"
        "               OpName %inTexCoords \"inTexCoords\"\n"
        "               OpName %gl_PerVertex \"gl_PerVertex\"\n"
        "               OpMemberName %gl_PerVertex 0 \"gl_Position\"\n"
        "               OpMemberName %gl_PerVertex 1 \"gl_PointSize\"\n"
        "               OpMemberName %gl_PerVertex 2 \"gl_ClipDistance\"\n"
        "               OpName %_ \"\"\n"
        "               OpName %buf \"buf\"\n"
        "               OpMemberName %buf 0 \"mvp\"\n"
        "               OpName %ubuf \"ubuf\"\n"
        "               OpName %pos \"pos\"\n"
        "               OpDecorate %texcoord Location 0\n"
        "               OpDecorate %inTexCoords Location 1\n"
        "               OpMemberDecorate %gl_PerVertex 0 BuiltIn Position\n"
        "               OpMemberDecorate %gl_PerVertex 1 BuiltIn PointSize\n"
        "               OpMemberDecorate %gl_PerVertex 2 BuiltIn ClipDistance\n"
        "               OpDecorate %gl_PerVertex Block\n"
        "               OpMemberDecorate %buf 0 ColMajor\n"
        "               OpMemberDecorate %buf 0 Offset 0\n"
        "               OpMemberDecorate %buf 0 MatrixStride 16\n"
        "               OpDecorate %buf Block\n"
        "               OpDecorate %ubuf DescriptorSet 0\n"
        "               OpDecorate %ubuf Binding 0\n"
        "               OpDecorate %pos Location 0\n"
        "       %void = OpTypeVoid\n"
        "          %3 = OpTypeFunction %void\n"
        "      %float = OpTypeFloat 32\n"
        "    %v2float = OpTypeVector %float 2\n"
        "%_ptr_Output_v2float = OpTypePointer Output %v2float\n"
        "   %texcoord = OpVariable %_ptr_Output_v2float Output\n"
        "%_ptr_Input_v2float = OpTypePointer Input %v2float\n"
        "%inTexCoords = OpVariable %_ptr_Input_v2float Input\n"
        "    %v4float = OpTypeVector %float 4\n"
        "       %uint = OpTypeInt 32 0\n"
        "     %uint_1 = OpConstant %uint 1\n"
        "%_arr_float_uint_1 = OpTypeArray %float %uint_1\n"
        "%gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_1\n"
        "%_ptr_Output_gl_PerVertex = OpTypePointer Output %gl_PerVertex\n"
        "          %_ = OpVariable %_ptr_Output_gl_PerVertex Output\n"
        "        %int = OpTypeInt 32 1\n"
        "      %int_0 = OpConstant %int 0\n"
        "%mat4v4float = OpTypeMatrix %v4float 4\n"
        "        %buf = OpTypeStruct %mat4v4float\n"
        "%_ptr_Uniform_buf = OpTypePointer Uniform %buf\n"
        "       %ubuf = OpVariable %_ptr_Uniform_buf Uniform\n"
        "%_ptr_Uniform_mat4v4float = OpTypePointer Uniform %mat4v4float\n"
        "%_ptr_Input_v4float = OpTypePointer Input %v4float\n"
        "        %pos = OpVariable %_ptr_Input_v4float Input\n"
        "%_ptr_Output_v4float = OpTypePointer Output %v4float\n"
        "       %main = OpFunction %void None %3\n"
        "          %5 = OpLabel\n"
        "         %12 = OpLoad %v2float %inTexCoords\n"
        "               OpStore %texcoord %12\n"
        "         %27 = OpAccessChain %_ptr_Uniform_mat4v4float %ubuf %int_0\n"
        "         %28 = OpLoad %mat4v4float %27\n"
        "         %31 = OpLoad %v4float %pos\n"
        "         %32 = OpMatrixTimesVector %v4float %28 %31\n"
        "         %34 = OpAccessChain %_ptr_Output_v4float %_ %int_0\n"
        "               OpStore %34 %32\n"
        "               OpReturn\n"
        "               OpFunctionEnd\n";

// Fragment shader in SPIR-V assembly text, assembled at runtime by sample_main().
// Unlike vertexSPIRV, the result IDs here are numeric rather than named; the
// OpName lines record which ID corresponds to which GLSL name.
// Interface, as declared by the decorations below:
//   - input  %17 "texcoord" : vec2, Location 0 (matches the vertex output)
//   - output %9  "outColor" : vec4, Location 0
//   - %13 "tex"             : 2D sampled image (OpTypeSampledImage) in
//                             UniformConstant storage, DescriptorSet 0, Binding 1
// The body samples tex at texcoord with an explicit LOD of 0.0 and writes the
// result to outColor.
const std::string fragmentSPIRV =
        "; SPIR-V\n"
        "; Version: 1.0\n"
        "; Generator: Khronos Glslang Reference Front End; 1\n"
        "; Bound: 21\n"
        "; Schema: 0\n"
        "               OpCapability Shader\n"
        "          %1 = OpExtInstImport \"GLSL.std.450\"\n"
        "               OpMemoryModel Logical GLSL450\n"
        "               OpEntryPoint Fragment %4 \"main\" %9 %17\n"
        "               OpExecutionMode %4 OriginUpperLeft\n"
        "               OpSource GLSL 400\n"
        "               OpSourceExtension \"GL_ARB_separate_shader_objects\"\n"
        "               OpSourceExtension \"GL_ARB_shading_language_420pack\"\n"
        "               OpName %4 \"main\"\n"
        "               OpName %9 \"outColor\"\n"
        "               OpName %13 \"tex\"\n"
        "               OpName %17 \"texcoord\"\n"
        "               OpDecorate %9 Location 0\n"
        "               OpDecorate %13 DescriptorSet 0\n"
        "               OpDecorate %13 Binding 1\n"
        "               OpDecorate %17 Location 0\n"
        "          %2 = OpTypeVoid\n"
        "          %3 = OpTypeFunction %2\n"
        "          %6 = OpTypeFloat 32\n"
        "          %7 = OpTypeVector %6 4\n"
        "          %8 = OpTypePointer Output %7\n"
        "          %9 = OpVariable %8 Output\n"
        "         %10 = OpTypeImage %6 2D 0 0 0 1 Unknown\n"
        "         %11 = OpTypeSampledImage %10\n"
        "         %12 = OpTypePointer UniformConstant %11\n"
        "         %13 = OpVariable %12 UniformConstant\n"
        "         %15 = OpTypeVector %6 2\n"
        "         %16 = OpTypePointer Input %15\n"
        "         %17 = OpVariable %16 Input\n"
        "         %19 = OpConstant %6 0\n"
        "          %4 = OpFunction %2 None %3\n"
        "          %5 = OpLabel\n"
        "         %14 = OpLoad %11 %13\n"
        "         %18 = OpLoad %15 %17\n"
        "         %20 = OpImageSampleExplicitLod %7 %14 %18 Lod %19\n"
        "               OpStore %9 %20\n"
        "               OpReturn\n"
        "               OpFunctionEnd\n";

// clang-format on
int sample_main(int argc, char *argv[]) {
    VkResult U_ASSERT_ONLY res;
    struct sample_info info = {};
    char sample_title[] = "SPIR-V Assembly";
    const bool depthPresent = true;

    process_command_line_args(info, argc, argv);
    init_global_layer_properties(info);
    init_instance_extension_names(info);
    init_device_extension_names(info);
    init_instance(info, sample_title);
    init_enumerate_device(info);
    init_window_size(info, 500, 500);
    init_connection(info);
    init_window(info);
    init_swapchain_extension(info);
    init_device(info);
    init_command_pool(info);
    init_command_buffer(info);
    execute_begin_command_buffer(info);
    init_device_queue(info);
    init_swap_chain(info);
    init_depth_buffer(info);
    init_texture(info);
    init_uniform_buffer(info);
    init_descriptor_and_pipeline_layouts(info, true);
    init_renderpass(info, depthPresent);

    /* VULKAN_KEY_START */

    // Init the assembler context
    spv_context spvContext = spvContextCreate(SPV_ENV_VULKAN_1_0);

    // Convert the vertex assembly into binary format
    spv_binary vertexBinary = {};
    spv_diagnostic vertexDiag = {};
    spv_result_t vertexResult = spvTextToBinary(spvContext, vertexSPIRV.c_str(), vertexSPIRV.length(), &vertexBinary, &vertexDiag);
    if (vertexDiag) {
        printf("Diagnostic info from vertex shader:\n");
        spvDiagnosticPrint(vertexDiag);
    }
    assert(vertexResult == SPV_SUCCESS);

    // Convert the fragment assembly into binary format
    spv_binary fragmentBinary = {};
    spv_diagnostic fragmentDiag = {};
    spv_result_t fragmentResult =
        spvTextToBinary(spvContext, fragmentSPIRV.c_str(), fragmentSPIRV.length(), &fragmentBinary, &fragmentDiag);
    if (fragmentDiag) {
        printf("Diagnostic info from fragment shader:\n");
        spvDiagnosticPrint(fragmentDiag);
    }
    assert(fragmentResult == SPV_SUCCESS);

    info.shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    info.shaderStages[0].pNext = NULL;
    info.shaderStages[0].pSpecializationInfo = NULL;
    info.shaderStages[0].flags = 0;
    info.shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
    info.shaderStages[0].pName = "main";
    VkShaderModuleCreateInfo moduleCreateInfo;
    moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
    moduleCreateInfo.pNext = NULL;
    moduleCreateInfo.flags = 0;
    // Use wordCount and code pointers from the spv_binary
    moduleCreateInfo.codeSize = vertexBinary->wordCount * sizeof(unsigned int);
    moduleCreateInfo.pCode = vertexBinary->code;
    res = vkCreateShaderModule(info.device, &moduleCreateInfo, NULL, &info.shaderStages[0].module);
    assert(res == VK_SUCCESS);

    info.shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    info.shaderStages[1].pNext = NULL;
    info.shaderStages[1].pSpecializationInfo = NULL;
    info.shaderStages[1].flags = 0;
    info.shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
    info.shaderStages[1].pName = "main";
    moduleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
    moduleCreateInfo.pNext = NULL;
    moduleCreateInfo.flags = 0;
    // Use wordCount and code pointers from the spv_binary
    moduleCreateInfo.codeSize = fragmentBinary->wordCount * sizeof(unsigned int);
    moduleCreateInfo.pCode = fragmentBinary->code;
    res = vkCreateShaderModule(info.device, &moduleCreateInfo, NULL, &info.shaderStages[1].module);
    assert(res == VK_SUCCESS);

    // Clean up the diagnostics
    spvDiagnosticDestroy(vertexDiag);
    spvDiagnosticDestroy(fragmentDiag);

    // Clean up the assembler context
    spvContextDestroy(spvContext);

    /* VULKAN_KEY_END */

    init_framebuffers(info, depthPresent);
    init_vertex_buffer(info, g_vb_texture_Data, sizeof(g_vb_texture_Data), sizeof(g_vb_texture_Data[0]), true);
    init_descriptor_pool(info, true);
    init_descriptor_set(info, true);
    init_pipeline_cache(info);
    init_pipeline(info, depthPresent);
    init_presentable_image(info);

    VkClearValue clear_values[2];
    init_clear_color_and_depth(info, clear_values);

    VkRenderPassBeginInfo rp_begin;
    init_render_pass_begin_info(info, rp_begin);
    rp_begin.clearValueCount = 2;
    rp_begin.pClearValues = clear_values;

    vkCmdBeginRenderPass(info.cmd, &rp_begin, VK_SUBPASS_CONTENTS_INLINE);

    vkCmdBindPipeline(info.cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, info.pipeline);
    vkCmdBindDescriptorSets(info.cmd, VK_PIPELINE_BIND_POINT_GRAPHICS, info.pipeline_layout, 0, NUM_DESCRIPTOR_SETS,
                            info.desc_set.data(), 0, NULL);

    const VkDeviceSize offsets[1] = {0};
    vkCmdBindVertexBuffers(info.cmd, 0, 1, &info.vertex_buffer.buf, offsets);

    init_viewports(info);
    init_scissors(info);

    vkCmdDraw(info.cmd, 12 * 3, 1, 0, 0);
    vkCmdEndRenderPass(info.cmd);
    res = vkEndCommandBuffer(info.cmd);
    assert(res == VK_SUCCESS);

    VkFence drawFence = {};
    init_fence(info, drawFence);
    VkPipelineStageFlags pipe_stage_flags = VK_PIPELINE_STAGE_COLOR_ATTACHMENT_OUTPUT_BIT;
    VkSubmitInfo submit_info = {};
    init_submit_info(info, submit_info, pipe_stage_flags);

    /* Queue the command buffer for execution */
    res = vkQueueSubmit(info.graphics_queue, 1, &submit_info, drawFence);
    assert(res == VK_SUCCESS);

    /* Now present the image in the window */
    VkPresentInfoKHR present = {};
    init_present_info(info, present);

    /* Make sure command buffer is finished before presenting */
    do {
        res = vkWaitForFences(info.device, 1, &drawFence, VK_TRUE, FENCE_TIMEOUT);
    } while (res == VK_TIMEOUT);
    assert(res == VK_SUCCESS);
    res = vkQueuePresentKHR(info.present_queue, &present);
    assert(res == VK_SUCCESS);

    wait_seconds(1);
    if (info.save_images) write_ppm(info, "spirv_assembly");

    vkDestroyFence(info.device, drawFence, NULL);
    vkDestroySemaphore(info.device, info.imageAcquiredSemaphore, NULL);
    destroy_pipeline(info);
    destroy_pipeline_cache(info);
    destroy_textures(info);
    destroy_descriptor_pool(info);
    destroy_vertex_buffer(info);
    destroy_framebuffers(info);
    destroy_shaders(info);
    destroy_renderpass(info);
    destroy_descriptor_and_pipeline_layouts(info);
    destroy_uniform_buffer(info);
    destroy_depth_buffer(info);
    destroy_swap_chain(info);
    destroy_command_buffer(info);
    destroy_command_pool(info);
    destroy_device(info);
    destroy_window(info);
    destroy_instance(info);
    return 0;
}
